package com.study.crypto.basic.signer;

import org.bouncycastle.crypto.CipherParameters;
import org.bouncycastle.crypto.CryptoException;
import org.bouncycastle.crypto.CryptoServicesRegistrar;
import org.bouncycastle.crypto.Digest;
import org.bouncycastle.crypto.digests.SM3Digest;
import org.bouncycastle.crypto.params.*;
import org.bouncycastle.crypto.signers.*;
import org.bouncycastle.math.ec.ECAlgorithms;
import org.bouncycastle.math.ec.ECMultiplier;
import org.bouncycastle.math.ec.ECPoint;

import java.math.BigInteger;

/**
 * SM2 signer that works on a caller-supplied hash value: the signature is computed directly
 * from the bytes passed to {@link #updateHashed(byte[])} instead of from the raw message.
 *
 * <p>The supplied bytes are converted with {@link #calculateE(BigInteger, byte[])} and used as the
 * SM2 value {@code e}. This class does not compute the identity hash {@code Z} or digest the message
 * itself, and the ID of a {@link ParametersWithID} passed to
 * {@link #initHashed(boolean, CipherParameters)} is ignored.
 * NOTE(review): callers are presumably expected to pass {@code H(Z || M)} — confirm.
 *
 * <p>Usage: {@code initHashed} → {@code updateHashed} → {@code generateHashedSignature} or
 * {@code verifyHashedSignature}. Every sign/verify call consumes the supplied hash, so
 * {@code updateHashed} must be called again before the next operation. Instances hold mutable,
 * unsynchronized per-operation state.
 *
 * @author Songjin
 * @since 2021-01-03 14:59
 */
public class HashedSM2Signer extends SM2Signer {

    private final DSAKCalculator kCalculator = new RandomDSAKCalculator();
    private final DSAEncoding encoding;
    private ECDomainParameters ecParams;
    private ECKeyParameters ecKey;
    /** Hash pending for the next sign/verify call; {@code null} when none has been supplied. */
    private byte[] hashData;

    /** Creates a signer using {@link StandardDSAEncoding} and SM3. */
    public HashedSM2Signer() {
        this(StandardDSAEncoding.INSTANCE, new SM3Digest());
    }

    /**
     * Creates a signer using {@link StandardDSAEncoding}.
     *
     * @param digest digest handed to the inherited message-based {@link SM2Signer} API; the hashed
     *     API of this class does not use it
     */
    public HashedSM2Signer(Digest digest) {
        this(StandardDSAEncoding.INSTANCE, digest);
    }

    /**
     * Creates a signer using SM3 for the inherited message-based API.
     *
     * @param encoding encoding used for signature bytes
     */
    public HashedSM2Signer(DSAEncoding encoding) {
        this(encoding, new SM3Digest());
    }

    /**
     * Creates a signer.
     *
     * @param encoding encoding used for signature bytes
     * @param digest digest handed to the inherited message-based {@link SM2Signer} API; the hashed
     *     API of this class does not use it
     */
    public HashedSM2Signer(DSAEncoding encoding, Digest digest) {
        // Forward to the parent so the inherited message-based API is configured with the same
        // encoding and digest as this instance.
        super(encoding, digest);
        this.encoding = encoding;
    }

    /**
     * Initialises the signer for hashed signing or verification.
     *
     * @param forSigning {@code true} to sign, {@code false} to verify
     * @param param for signing: an {@link ECPrivateKeyParameters}, optionally wrapped in
     *     {@link ParametersWithRandom}; for verifying: an {@link ECPublicKeyParameters}. Either may
     *     additionally be wrapped in {@link ParametersWithID}, whose ID is ignored.
     * @throws IllegalArgumentException if the key does not match {@code forSigning}
     */
    public void initHashed(boolean forSigning, CipherParameters param) {
        CipherParameters baseParam = param instanceof ParametersWithID
                ? ((ParametersWithID) param).getParameters()
                : param;

        if (forSigning) {
            boolean hasRandom = baseParam instanceof ParametersWithRandom;
            CipherParameters keyParam = hasRandom
                    ? ((ParametersWithRandom) baseParam).getParameters()
                    : baseParam;
            if (!(keyParam instanceof ECPrivateKeyParameters)) {
                throw new IllegalArgumentException("signing requires an EC private key");
            }
            this.ecKey = (ECKeyParameters) keyParam;
            this.ecParams = this.ecKey.getParameters();
            // Use the caller's randomness when supplied, otherwise the registrar default.
            this.kCalculator.init(this.ecParams.getN(), hasRandom
                    ? ((ParametersWithRandom) baseParam).getRandom()
                    : CryptoServicesRegistrar.getSecureRandom());
        } else {
            if (!(baseParam instanceof ECPublicKeyParameters)) {
                throw new IllegalArgumentException("verification requires an EC public key");
            }
            this.ecKey = (ECKeyParameters) baseParam;
            this.ecParams = this.ecKey.getParameters();
        }
    }

    /**
     * Supplies the hash value for the next sign or verify operation. A copy is stored, so later
     * changes to {@code b} do not affect the pending operation.
     *
     * @param b the precomputed hash, converted directly into the SM2 value {@code e}
     */
    public void updateHashed(byte[] b) {
        this.hashData = b == null ? null : b.clone();
    }

    /**
     * Verifies an encoded signature against the hash supplied via {@link #updateHashed(byte[])}.
     * The pending hash is consumed whether or not verification succeeds.
     *
     * @param signature signature bytes in this signer's {@link DSAEncoding}
     * @return {@code true} if the signature is valid; {@code false} if it is invalid or cannot be
     *     decoded
     * @throws IllegalStateException if no hash has been supplied
     */
    public boolean verifyHashedSignature(byte[] signature) {
        // Take the hash before anything can fail, so a rejected signature never leaves a stale
        // hash pending for the next operation.
        byte[] eHash = takeHashData();
        try {
            BigInteger[] rs = this.encoding.decode(this.ecParams.getN(), signature);
            return verifyHashedSignature(rs[0], rs[1], eHash);
        } catch (Exception ignored) {
            // Any decoding or verification failure is reported as an invalid signature.
            return false;
        }
    }

    /**
     * Signs the hash supplied via {@link #updateHashed(byte[])} and consumes it.
     *
     * @return the signature encoded with this signer's {@link DSAEncoding}
     * @throws IllegalStateException if not initialised for signing or no hash has been supplied
     * @throws CryptoException if the signature cannot be encoded
     */
    public byte[] generateHashedSignature() throws CryptoException {
        if (!(this.ecKey instanceof ECPrivateKeyParameters)) {
            throw new IllegalStateException("signer not initialised for signing");
        }
        byte[] eHash = takeHashData();
        BigInteger n = this.ecParams.getN();
        BigInteger e = calculateE(n, eHash);
        BigInteger d = ((ECPrivateKeyParameters) this.ecKey).getD();
        // (1 + d)^-1 mod n does not depend on k, so compute it once outside the retry loop.
        BigInteger dPlus1Inv = d.add(ONE).modInverse(n);
        ECMultiplier basePointMultiplier = createBasePointMultiplier();
        ECPoint g = this.ecParams.getG();

        while (true) {
            BigInteger k = this.kCalculator.nextK();
            ECPoint p = basePointMultiplier.multiply(g, k).normalize();
            // r = (e + x1) mod n; draw a fresh k if r == 0 or r + k == n.
            BigInteger r = e.add(p.getAffineXCoord().toBigInteger()).mod(n);
            if (r.equals(ZERO) || r.add(k).equals(n)) {
                continue;
            }
            // s = (1 + d)^-1 * (k - r*d) mod n; draw a fresh k if s == 0.
            BigInteger s = dPlus1Inv.multiply(k.subtract(r.multiply(d))).mod(n);
            if (s.equals(ZERO)) {
                continue;
            }
            try {
                return this.encoding.encode(n, r, s);
            } catch (Exception ex) {
                throw new CryptoException("unable to encode signature: " + ex.getMessage(), ex);
            }
        }
    }

    /**
     * Checks the (r, s) signature components against the given hash using the configured public key.
     */
    private boolean verifyHashedSignature(BigInteger r, BigInteger s, byte[] eHash) {
        BigInteger n = this.ecParams.getN();
        // Both components must lie in [1, n - 1].
        if (r.compareTo(ONE) < 0 || r.compareTo(n) >= 0) {
            return false;
        }
        if (s.compareTo(ONE) < 0 || s.compareTo(n) >= 0) {
            return false;
        }

        BigInteger e = calculateE(n, eHash);
        BigInteger t = r.add(s).mod(n);
        if (t.equals(ZERO)) {
            return false;
        }
        // (x1, y1) = s*G + t*Q
        ECPoint q = ((ECPublicKeyParameters) this.ecKey).getQ();
        ECPoint x1y1 = ECAlgorithms.sumOfTwoMultiplies(this.ecParams.getG(), s, q, t).normalize();
        if (x1y1.isInfinity()) {
            return false;
        }
        BigInteger expectedR = e.add(x1y1.getAffineXCoord().toBigInteger()).mod(n);
        return expectedR.equals(r);
    }

    /**
     * Returns the pending hash and clears it, so each supplied hash is used for exactly one
     * operation.
     *
     * @throws IllegalStateException if no hash is pending
     */
    private byte[] takeHashData() {
        byte[] result = this.hashData;
        if (result == null) {
            throw new IllegalStateException(
                    "no hash data: call updateHashed() before signing or verifying");
        }
        this.hashData = null;
        return result;
    }
}
